import math
import oneflow as torch
import oneflow.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal absolute positional encoding.

    The input is scaled by ``sqrt(d_model)`` and a fixed sinusoidal table is
    added along dim 1 (time). The table is cached in ``self.pe`` with shape
    ``(1, length, d_model)`` and grown lazily when longer inputs arrive.
    ``self.pe`` is a plain attribute (not a registered buffer), so it follows
    the input's device/dtype via :meth:`extend_pe` rather than ``module.to``.
    """

    def __init__(self, d_model, dropout_rate=0.0, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: initial number of precomputed positions (the table
            is extended on demand for longer inputs)

        """
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.pe = None
        # extend_pe only reads size(1), dtype and device from its argument, so
        # a broadcast view of a scalar is enough to pre-build the table.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x, length=None):
        """Ensure ``self.pe`` covers ``length`` positions and matches ``x``.

        :param x: reference tensor; the table is cast to its dtype/device and
            ``x.size(1)`` is the default required length
        :param int length: number of positions required (defaults to
            ``x.size(1)``)

        """
        if length is None:
            length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length:
            # Long enough already; only re-cast when dtype/device differ.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term  # (length, ceil(d_model / 2))
        pe = torch.zeros(length, self.d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        self.pe = pe.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)

    def inference(self, x, startid=0):
        """Add positional encoding for positions ``startid .. startid + T - 1``.

        No dropout is applied.

        Args:
            x (torch.Tensor): Input of shape (batch, T, ...).
            startid (int): Absolute position of the first time step of ``x``.

        Returns:
            tuple: ``(encoded, None)``.

        """
        end = startid + x.size(1)
        # The table must reach `end`, not just x.size(1); otherwise the slice
        # below is short and either fails or silently broadcasts one row.
        self.extend_pe(x, end)
        x = x * self.xscale + self.pe[:, startid:end]
        return x, None


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    Computes ``x + alpha * pe`` where ``alpha`` is a learnable scalar
    initialized to 1 (the input itself is not scaled).

    See also: Sec. 3.2  https://arxiv.org/pdf/1809.08895.pdf

    """

    def __init__(self, d_model, dropout_rate=0.0, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length

        """
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset the learnable scale ``alpha`` to 1."""
        # Write in place: rebinding .data to torch.tensor(1.0) would move alpha
        # to CPU even if the module lives on another device.
        self.alpha.data.fill_(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)

    def inference(self, x, startid=0):
        """Add scaled encoding for positions ``startid .. startid + T - 1``.

        No dropout is applied. Returns ``(encoded, None)``.
        """
        end = startid + x.size(1)
        # extend_pe sizes the table from its argument's size(1) and copies its
        # dtype/device; a cheap expanded view of length `end` guarantees the
        # slice below is not truncated.
        self.extend_pe(x.new_zeros((1, 1)).expand(1, end))
        x = x + self.alpha * self.pe[:, startid:end]
        return x, None


class MixedPositionalEncoding(PositionalEncoding):
    """Sinusoidal positional encoding with a selectable scaling mode.

    * ``scale_learnable=False``: ``x * sqrt(d_model) + pe``.
    * ``scale_learnable=True``: ``x + alpha * pe`` with a learnable scalar
      ``alpha`` initialized to 1.

    Unlike the parent class, :meth:`forward` returns a ``(tensor, None)`` tuple.
    """

    def __init__(self, d_model, dropout_rate=0.0, max_len=5000, scale_learnable=False):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length
        :param bool scale_learnable: scale the encoding by a learnable
            ``alpha`` instead of scaling the input by ``sqrt(d_model)``

        """
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)

        self.scale_learnable = scale_learnable

        if self.scale_learnable:
            self.alpha = nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset ``alpha`` to 1; a no-op when the scale is not learnable."""
        # Without this guard the fixed-scale mode raised AttributeError,
        # because `alpha` only exists when scale_learnable is True.
        if self.scale_learnable:
            # In-place write keeps alpha on its current device/dtype.
            self.alpha.data.fill_(1.0)

    def _add_pe(self, x, pe):
        """Combine input ``x`` with an already-sliced encoding ``pe``."""
        if self.scale_learnable:
            return x + self.alpha * pe
        return x * self.xscale + pe

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            tuple: ``(encoded, None)`` where ``encoded`` has the shape of ``x``
            and has had dropout applied.

        """
        self.extend_pe(x)
        x = self._add_pe(x, self.pe[:, : x.size(1)])
        return self.dropout(x), None

    def inference(self, x, startid=0):
        """Add encoding for positions ``startid .. startid + T - 1``.

        No dropout is applied. Returns ``(encoded, None)``.
        """
        end = startid + x.size(1)
        # extend_pe sizes the table from its argument's size(1) and copies its
        # dtype/device; an expanded view of length `end` ensures the slice
        # below is not truncated.
        self.extend_pe(x.new_zeros((1, 1)).expand(1, end))
        return self._add_pe(x, self.pe[:, startid:end]), None


class RelPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding computed from explicit positions.

    The embedding is recomputed on every call (no cached table). Both
    :meth:`forward` and :meth:`forward_from_pos` return the encoded input
    together with the positional embedding that was added.
    """

    def __init__(self, emb_dim, scale_learnable=False, dropout=0.0):
        """Initialize class.

        :param int emb_dim: embedding dim
        :param bool scale_learnable: if True, compute ``x + alpha * posemb``
            with a learnable scalar ``alpha`` (initialized to 1); otherwise
            ``x * sqrt(emb_dim) + posemb``
        :param float dropout: dropout rate applied in :meth:`forward`

        """
        super(RelPositionalEncoding, self).__init__()
        self.emb_dim = emb_dim
        self.xscale = math.sqrt(self.emb_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.scale_learnable = scale_learnable

        if self.scale_learnable:
            self.alpha = nn.Parameter(torch.tensor(1.0))

    def _embedding_from_positions(self, position):
        """Get the absolute sinusoidal embedding for the given positions.

        Args:
            position (torch.Tensor): Integer or float positions, shape (b, t).
        Returns:
            posemb (torch.Tensor): float32 tensor of shape (b, t, emb_dim).
        """
        batch_size, time_step = position.size()
        device = position.device
        div_term = torch.exp(
            torch.arange(0, self.emb_dim, 2, device=device, dtype=torch.float32)
            * -(math.log(10000.0) / self.emb_dim)
        )
        angles = position.float().unsqueeze(-1) * div_term  # (b, t, ceil(emb_dim / 2))
        posemb = torch.zeros(batch_size, time_step, self.emb_dim, device=device)
        posemb[:, :, 0::2] = torch.sin(angles)
        # For odd emb_dim there is one fewer odd column than even column.
        posemb[:, :, 1::2] = torch.cos(angles[:, :, : self.emb_dim // 2])
        return posemb

    def forward(self, x: torch.Tensor):
        """Add positional encoding for positions ``0 .. T - 1``.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, emb_dim)
        Returns:
            tuple: ``(encoded, posemb)``; ``encoded`` has the shape of ``x``
            with dropout applied, ``posemb`` has shape (1, time, emb_dim).
        """
        pos = torch.arange(0, x.size(1), device=x.device).reshape(1, -1)  # [1, t]
        posemb = self._embedding_from_positions(pos)  # [1, t, emb_dim]
        if self.scale_learnable:
            x = x + self.alpha * posemb
        else:
            x = x * self.xscale + posemb
        return self.dropout(x), posemb

    def forward_from_pos(self, x, pos):
        """Add positional encoding for explicit per-sample positions.

        No dropout is applied here (unlike :meth:`forward`).

        Args:
            x (torch.Tensor): Input. Its shape is (b, t, emb_dim)
            pos (torch.Tensor): Positions. Its shape is (b, t)
        Returns:
            tuple: ``(encoded, posemb)``, both of shape (b, t, emb_dim).
        """
        posemb = self._embedding_from_positions(pos)  # [b, t, emb_dim]
        if self.scale_learnable:
            x = x + self.alpha * posemb
        else:
            x = x * self.xscale + posemb
        return x, posemb

    def inference(self, x, step):
        """Step-wise inference is not supported; use ``forward_from_pos``."""
        raise NotImplementedError(
            "RelPositionalEncoding.inference is not implemented; "
            "use forward_from_pos with explicit positions instead"
        )